import numpy as np
from matplotlib import pyplot as plt
from glob import glob
from torchvision.transforms import transforms
from PIL import Image
import torch
import matplotlib
matplotlib.use('Agg')

# Per-channel (R, G, B) ImageNet normalization statistics; show_image uses
# them to undo the normalization applied before the model sees an image.
imagenet_mean = np.array((0.485, 0.456, 0.406))
imagenet_std = np.array((0.229, 0.224, 0.225))


@torch.no_grad()
def get_demo_predictions(args, device, model):
    """Return the demo reconstruction figures keyed for logging.

    ``device`` is part of the signature but is not used by this function.
    The result maps ``"image_<i>"`` to the i-th figure produced by
    ``get_demo_predictions_with_mask(args, model)``.
    """
    named = {}
    for idx, figure in enumerate(get_demo_predictions_with_mask(args, model)):
        named["image_%s" % idx] = figure
    return named


def show_image(image, ax):
    """Draw one normalized RGB image on ``ax`` and hide the axis frame.

    ``image`` is indexed as [H, W, 3] and must be a torch tensor, because
    ``torch.clip`` is applied to it. The module-level ImageNet statistics
    are used to undo normalization. Values are then scaled to 0-255,
    clipped, and cast to int before being passed to ``ax.imshow``.
    """
    # image is [H, W, 3]
    assert image.shape[2] == 3
    pixels = (image * imagenet_std + imagenet_mean) * 255
    ax.imshow(torch.clip(pixels, 0, 255).int())
    ax.axis('off')


def _load_demo_images(args):
    """Load up to 8 demo images as one normalized (N, 3, H, W) float tensor.

    Non-CIFAR data paths read from a fixed ImageNet validation folder. Any
    alpha channel is composited onto white, and each image is resized to
    ``args.input_size`` square. CIFAR data paths read from ``images/`` and
    keep each image's stored size.

    Raises:
        RuntimeError: if the glob pattern matches no files.
    """
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean.tolist(), std=imagenet_std.tolist())])
    is_cifar = 'cifar' in args.data_path
    pattern = "images/*" if is_cifar else "/home/group/ilsvrc/val/n02916936/*"
    # glob() returns matches in arbitrary order; sort so every run picks the same images.
    paths = sorted(glob(pattern))[:8]
    if not paths:
        raise RuntimeError("no demo images found matching %r" % pattern)

    imgs = []
    for p in paths:
        with open(p, 'rb') as f:
            img = Image.open(f)
            if is_cifar:
                img = img.convert('RGB')
            else:
                # Flatten transparency onto a white background, then resize straight
                # to the model's input size. (The old Resize(256)+CenterCrop(224)
                # undid this resize, so output was always 224 px.)
                rgba = img.convert('RGBA')
                background = Image.new('RGBA', rgba.size, (255, 255, 255))
                img = Image.alpha_composite(background, rgba).convert('RGB').resize(
                    (args.input_size, args.input_size), resample=Image.LANCZOS)
            imgs.append(to_tensor(img))
    return torch.stack(imgs, dim=0)


def _make_figures(x, im_masked, y, im_paste):
    """Lay out images in pyplot figures of 4 rows each.

    Each row shows, left to right: the original, the masked input, the
    reconstruction, and the reconstruction with visible patches pasted back.
    All four inputs are indexed per image along their first dimension.
    """
    panels = (x, im_masked, y, im_paste)
    rows_per_fig = 4
    n = len(x)
    figs = []
    for start in range(0, n, rows_per_fig):
        fig, ax = plt.subplots(rows_per_fig, len(panels), figsize=(10, 10))
        fig.subplots_adjust(wspace=0, hspace=0)
        for row in range(rows_per_fig):
            idx = start + row
            for col, panel in enumerate(panels):
                cell = ax[row, col]
                if idx >= n:
                    # The last figure may be partially filled; hide its unused slots.
                    cell.axis('off')
                    continue
                show_image(panel[idx], cell)
                cell.set_xticklabels([])
                cell.set_yticklabels([])
                cell.set_aspect('equal')
        figs.append(fig)
    return figs


@torch.no_grad()
def get_demo_predictions_with_mask(args, model, *, device=None, mask_ratio=0.75):
    """Run the model on a few demo images and plot the reconstructions.

    Args:
        args: needs ``data_path`` (a substring check for ``'cifar'`` picks the
            image source) and, for non-CIFAR paths, ``input_size``.
        model: called as ``model(x, mask_ratio=...)`` and expected to return a
            4-tuple whose 2nd and 3rd items are the prediction and the mask.
            It must also expose ``unpatchify`` and ``patch_embed.patch_size``.
        device: where to run the forward pass. Defaults to the device of the
            model's first parameter, or CUDA if the model has no parameters.
        mask_ratio: forwarded to the model (default 0.75).

    Returns:
        A list of matplotlib figures, each holding up to 4 images.

    Raises:
        RuntimeError: if no demo images are found.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

    imgs = _load_demo_images(args)
    x = imgs.to(device, non_blocking=True)
    _, y, mask, _ = model(x.float(), mask_ratio=mask_ratio)
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # Broadcast each per-patch mask value over that patch's p*p*3 pixel values
    # so unpatchify can map the mask back to image space. Presumably the result
    # is (N, num_patches, p*p*3); confirm against the model.
    patch_size = model.patch_embed.patch_size[0]
    mask = mask.detach().unsqueeze(-1).repeat(1, 1, patch_size ** 2 * 3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    # Move the input to the mask's device/dtype (CPU) for the arithmetic below.
    x = torch.einsum('nchw->nhwc', x).to(mask)

    # Masked image: removed patches are zeroed.
    im_masked = x * (1 - mask)
    # Reconstruction with the visible patches pasted back in.
    im_paste = x * (1 - mask) + y * mask

    return _make_figures(x, im_masked, y, im_paste)
